// Package lex2 实现了词法分析器
package lex2

import (
	"io"
	"math"
	"strings"
	"unsafe"

	"gitee.com/u-language/u-language/ucom/errcode"
)

// 一个文件的词法分析结果
type FileToken struct {
	// 文件名
	File string
	//源代码
	str    string
	errctx *errcode.ErrCtx
	// 未被扫描的最左边
	left int
	// 行号
	Line int
	//是否启用注释
	enableComments bool
}

// NewFileToken 创建一个文件的词法分析结果
//   - file是文件名
//   - src是源代码
//   - errctx是错误处理上下文
//   - enableComments设置是否启用注释
func NewFileToken(file string, src io.Reader, errctx *errcode.ErrCtx, enableComments bool) FileToken {
	SrcSlice, err := io.ReadAll(src)
	if err != nil {
		panic(err)
	}
	if !(len(SrcSlice) < math.MaxInt) {
		panic("过大的字符串")
	}
	return FileToken{File: file, str: string(SrcSlice), errctx: errctx, Line: 1, left: -1, enableComments: enableComments}
}

// Next 返回下一个未被识别的 [Token]
func (l *FileToken) Next() Token {
	l.left++
	for l.left < len(l.str) {
		b := getByte(l.str, l.left, "l.left >=0 && l.left < len(l.str)")
	selector:
		switch b {
		case '\n':
			l.Line++
			if l.left == len(l.str)-1 {
				return NewToken(EOF, "")
			}
			return NewToken(NewLine, "\n")
		case '\r': // /r除非后有/n，否则忽略
			if l.left+1 >= len(l.str) { //如果到位末尾，视为结束
				return NewToken(EOF, "")
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '\n' {
				l.left++
				l.Line++
				return NewToken(NewLine, "\n")
			}
			l.left++
		case ' ', '\t':
			for i := l.left + 1; i < len(l.str); i++ { //连续的空格或制表符跳过
				b := getByte(l.str, i, "i > 0 && i < len(l.str)")
				if b != ' ' && b != '\t' {
					l.left = i
					break selector
				}
			}
			l.left++ //如果只有一个空格或制表符
		case '(':
			return NewToken(LPAREN, "(")
		case ')':
			return NewToken(RPAREN, ")")
		case '{':
			return NewToken(LBRACE, "{")
		case '}':
			return NewToken(RBRACE, "}")
		case '[':
			return NewToken(LBRACK, "[")
		case ']':
			return NewToken(RBRACK, "]")
		case ',':
			return NewToken(Comma, ",")
		case ':':
			return NewToken(Colon, ":")
		case '<':
			return NewToken(Less, "<")
		case '>':
			return NewToken(Greater, ">")
		case ';':
			return NewToken(SEMICOLON, ";")
		case '%':
			return NewToken(Remain, "%")
		case '^':
			return NewToken(Xor, "^")
		case '.':
			return NewToken(PERIOD, ".")
		case '@':
			return NewToken(Deref, "@")
		case '=':
			if l.left+1 >= len(l.str) {
				return NewToken(ASSIGN, "=")
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '=' {
				l.left++
				return NewToken(Equal, "==")
			}
			return NewToken(ASSIGN, "=")
		case '+':
			if l.left+1 >= len(l.str) {
				return NewToken(ADD, "+")
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '+' {
				l.left++
				return NewToken(Inc, "++")
			}
			return NewToken(ADD, "+")
		case '-':
			if l.left+1 >= len(l.str) {
				return NewToken(SUB, "-")
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '-' {
				l.left++
				return NewToken(Dec, "--")
			}
			return NewToken(SUB, "-")
		case '*':
			return NewToken(MUL, "*")
		case '/':
			if l.left+1 >= len(l.str) {
				return NewToken(DIV, "/")
			}
			switch getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") {
			case '*': //如果是/*，处理多行注释
				for i := l.left + 2; i < len(l.str); i++ {
					switch getByte(l.str, i, "i > 0 && i < len(l.str)") {
					case '\n':
						l.Line++
					case '*':
						if i+1 < len(l.str) && getByte(l.str, i+1, "i+1 > 0 && i+1 < len(l.str)") == '/' { //注释到*/
							old := l.left
							l.left = i + 2
							if !l.enableComments { //如果没启用注释
								break selector
							}
							return NewToken(MLC, getStr(l.str, old, l.left, "l.left 不会导致访问超出字符串末尾的内存"))
						}
					}
				}
				//执行到这里说明多行注释没有结束
				l.left = len(l.str)
				l.errctx.Panic(l.File, l.Line, nil, errcode.MLCNoEnd)
				break selector
			case '/': //如果是//，处理单行注释
				old := l.left
				for i := l.left + 2; i < len(l.str); i++ {
					if getByte(l.str, i, "i > 0 && i < len(l.str)") == '\n' { //注释到一行结束
						l.left = i
						l.Line++
						if !l.enableComments { //如果没启用注释
							return NewToken(NewLine, "\n")
						}
						return NewToken(SinLineComments, getStr(l.str, old, l.left+1, "old > 0 && i < len(l.str)"))
					}
				}
				l.left = len(l.str)
				if !l.enableComments { //如果没启用注释，执行到这里说明字符串扫描完了
					return NewToken(EOF, "")
				}
				return NewToken(SinLineComments, getStr(l.str, old, len(l.str), "old > 0"))
			}
			return NewToken(DIV, "/")
		case '!':
			if l.left+1 >= len(l.str) {
				goto err
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '=' {
				l.left++
				return NewToken(NoEqual, "!=")
			}
		err:
			l.errctx.Panic(l.File, l.Line, errcode.NewMsgUnexpected("!"))
			return NewToken(EOF, "")
		case '|':
			if l.left+1 >= len(l.str) {
				return NewToken(Or, "|")
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '|' {
				l.left++
				return NewToken(LogicOr, "||")
			}
			return NewToken(Or, "|")
		case '&':
			if l.left+1 >= len(l.str) {
				return NewToken(LEA, "&")
			}
			if getByte(l.str, l.left+1, "l.left >=0 && l.left+1 < len(l.str)") == '&' {
				l.left++
				return NewToken(LogicAND, "&&")
			}
			return NewToken(LEA, "&")
		case '"':
			for i := l.left + 1; i < len(l.str); i++ {
				b := getByte(l.str, i, "i > 0 && i < len(l.str)")
				if b == '\n' { //字符串应该在一行
					goto err2
				}
				if b == '"' && getByte(l.str, i-1, "i-1 >= 0 && i-1 < len(l.str)") != '\\' {
					old := l.left
					l.left = i
					return NewToken(String, getStr(l.str, old, i+1, "old > 0 && i+1不会导致访问超出字符串的内存"))
				}
			}
		err2:
			l.errctx.Panic(l.File, l.Line, nil, errcode.StringNoEnd)
			return NewToken(EOF, "")
		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
			decimal_point := 0
			i := l.left + 1
		num_loop:
			for ; i < len(l.str); i++ {
				switch getByte(l.str, i, "i >=0 && i < len(l.str)") {
				case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
				case '.':
					decimal_point++
					if decimal_point != 1 { //不止一个小数点
						l.errctx.Panic(l.File, l.Line, errcode.NewMsgUnexpected("."))
						return NewToken(EOF, "")
					}
				default:
					break num_loop
				}
			}
			// 如果只有一个数字
			old := l.left
			l.left = i - 1
			kind := Int
			if decimal_point == 1 {
				kind = FLOAT
			}
			return NewToken(kind, getStr(l.str, old, i, "old > 0"))
		default:
			i := l.left + 1
		loop_end:
			for ; i < len(l.str); i++ {
				b := getByte(l.str, i, "i > 0 && i < len(l.str)")
				switch b {
				case '\r', '\n':
					break loop_end
				default:
					if hashSep(b) {
						break loop_end
					}
				}
			}
			old := l.left
			l.left = i - 1
			return l.parserName(getStr(l.str, old, i, "old > 0 && l.left 不会访问超出字符串的内存"))
		}
	}
	return NewToken(EOF, "")
}

type str struct {
	ptr unsafe.Pointer
	len int
}

func getStr(s string, start int, end int, whysafe string) string {
	sp := (*str)(unsafe.Pointer(&s))
	ret := str{ptr: unsafe.Add(sp.ptr, start), len: end - start}
	return *(*string)(unsafe.Pointer(&ret))
}

func getByte(s string, index int, whysafe string) byte {
	sp := (*str)(unsafe.Pointer(&s))
	return *(*byte)(unsafe.Add(sp.ptr, index))
}

func hashSep(s byte) bool {
	return seqHashMap[s]
}

func init() {
	for _, v := range []byte{
		' ', '\t', ';', //空格或制表符或分号
		'+', '-', '*', '/', '&', '@', '!', '=', '|', '%', '^', '<', '>', //运算符
		'(', ')', '{', '}', '[', ']', //括号
		'"', ',', '.', ':',
	} {
		seqHashMap[v] = true
	}
}

var seqHashMap [257]bool

func (l *FileToken) parserName(s string) Token {
	if len(s) >= 2 {
		if kind := hashMap[hash(s)]; kind.TYPE != No && kind.Value == s {
			return kind
		}
	}
	return NewToken(NAME, s) //如果其他规则不能成功分析出单词标记，分析为符号
}

func (l *FileToken) String() string {
	var buf strings.Builder
	l.left = -1
	l.Line = 0
	buf.Grow(len(l.str))
	buf.WriteString("{")
	buf.WriteString("\t\n")
	for tk := l.Next(); tk.TYPE != EOF; tk = l.Next() {
		if tk.TYPE == NewLine {
			buf.WriteString("\n")
			continue
		}
		buf.WriteString(tk.String())
		buf.WriteString("\t")
	}
	l.left = -1
	l.Line = 0
	buf.WriteString("}")
	return buf.String()
}

// EOFIsNewLine 返回是否是以换行作为文件结束
// 如果前面没有token，返回true
func (l *FileToken) EOFIsNewLine() bool {
	if len(l.str) == 0 {
		return false
	}
	return l.str[len(l.str)-1] == '\n'
}
